# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate data for mcore RoPE UT of inference."""
import numpy as np


# Shared configuration constants for the RoPE test data in this file.
# KV_CHANNELS * NUM_HEAD is the column count (last dimension) that every
# hard-coded array below is reshaped to, and that get_init_params uses
# together with its own kv_channels argument.
KV_CHANNELS = 32
NUM_HEAD = 1
# NOTE(review): not referenced in the visible part of this file; presumably
# the fraction of channels the rotary embedding is applied to -- confirm
# against the test that imports it.
ROTARY_PERCENT = 1.0
# NOTE(review): not referenced in the visible part of this file; presumably
# the maximum number of positions for the rotary table -- confirm against
# the test that imports it.
MAX_POSITION_EMBEDDING = 1024


def get_init_params(batch_size, seq_length, kv_channels):
    """Create reproducible random query/key inputs for prefill and decode.

    The global NumPy random state is re-seeded with 2025 on every call, so
    calling this function twice with the same arguments yields identical
    arrays (and, as a side effect, resets the global NumPy random stream).

    Args:
        batch_size: Batch size; the decode arrays have ``batch_size * 1`` rows.
        seq_length: Sequence length; the prefill arrays have
            ``batch_size * seq_length`` rows.
        kv_channels: Multiplied by ``NUM_HEAD`` to get the number of columns.

    Returns:
        dict: Arrays keyed ``query_for_prefill``, ``key_for_prefill``,
        ``query_for_decode`` and ``key_for_decode``, each sampled from a
        normal distribution with mean 0 and standard deviation 0.01.
    """
    np.random.seed(2025)

    width = kv_channels * NUM_HEAD
    prefill_rows = batch_size * seq_length
    decode_rows = batch_size * 1

    # The order of the draws matters: each call consumes the seeded global
    # random stream, so reordering entries would change every value.
    draw_plan = (
        ("query_for_prefill", prefill_rows),
        ("key_for_prefill", prefill_rows),
        ("query_for_decode", decode_rows),
        ("key_for_decode", decode_rows),
    )
    return {
        name: np.random.normal(loc=0, scale=0.01, size=(rows, width))
        for name, rows in draw_plan
    }


def get_golden() -> dict[str, np.ndarray]:
    """Return the hard-coded golden (expected) arrays used by the test.

    Every array is written as a nested float32 literal and then reshaped to
    2-D with ``KV_CHANNELS * NUM_HEAD`` columns. Each ``*_for_prefill``
    literal holds four rows of 32 values and each ``*_for_decode`` literal
    holds two rows of 32 values.

    How these values were produced is not shown in this file.

    Returns:
        dict[str, np.ndarray]: Twelve arrays keyed
        ``<prefix>_<q|k>_emb1_for_<prefill|decode>``, where ``<prefix>`` is
        ``rope``, ``llama3_rope`` or ``yarn_rope``.
    """
    # ---- Values returned under the "rope_*" keys ----
    # NOTE(review): presumably the default (unscaled) RoPE variant -- confirm
    # against the consuming test.
    rope_q_emb1_for_prefill = np.array(
        [[[-9.23901796e-04, 7.34285591e-03, -1.43888202e-02,
           -6.63422048e-03, -1.00728089e-03, 2.14698650e-02,
           1.39370812e-02, -2.07371125e-03, 7.63587281e-03,
           2.05053645e-03, -1.04149617e-02, -8.28076061e-03,
           -1.05379568e-02, -8.87586619e-04, 1.58193358e-03,
           -1.78932305e-02, 1.06348181e-02, 3.71200062e-04,
           -1.63757929e-03, -3.28207575e-03, 1.58897657e-02,
           -1.88783021e-03, -1.81722967e-03, 7.17900228e-03,
           -7.07250865e-06, 7.82166980e-03, 8.69911222e-04,
           6.62555685e-04, 1.20056495e-02, -6.16653124e-04,
           -9.00906324e-03, -1.25924805e-02],
          [5.99055039e-03, -1.38755748e-02, 7.45566655e-03,
           5.20251971e-03, -2.42165476e-03, -5.81937097e-03,
           -1.01434893e-03, 5.70974918e-03, 8.67953897e-03,
           -1.19327987e-02, 1.90028343e-02, -6.43109111e-03,
           -1.18313320e-02, 1.55276973e-02, 1.21676121e-02,
           9.36487690e-03, -3.17346607e-03, 3.89497355e-03,
           1.28487998e-03, -1.32205570e-02, -6.99169608e-03,
           -1.29683232e-02, 8.99381470e-03, 1.18395407e-02,
           -1.03066331e-02, -1.63140660e-03, 4.00607521e-03,
           -1.16742002e-02, 9.35880188e-03, -1.00374334e-02,
           -1.39038442e-02, -8.46103486e-03],
          [-3.72082670e-03, -8.59507546e-03, 1.20599866e-02,
           1.88908107e-05, -1.42772831e-02, 3.43290414e-03,
           -1.43426992e-02, 1.53077245e-02, -1.20205143e-02,
           -4.45510959e-03, 1.91527209e-03, -7.78383156e-03,
           -9.85337514e-03, 2.52749771e-03, -1.56111876e-02,
           1.04025351e-02, -1.29258167e-02, -1.25538278e-02,
           -1.44812127e-03, -2.21166927e-02, -1.78887552e-04,
           -7.29634427e-03, 3.33205494e-03, -1.08462316e-03,
           8.99155159e-03, -8.99989158e-03, 7.95922987e-03,
           -2.33992505e-05, -5.18175308e-03, 2.13852036e-03,
           -7.76226446e-03, -5.92254195e-03],
          [-1.52982026e-02, -2.72572273e-03, -4.52261744e-03,
           -2.08162479e-02, 1.53162312e-02, 1.74533762e-02,
           -5.69273718e-03, 8.48765578e-03, 5.49183320e-03,
           -4.22995351e-03, -2.07138131e-03, 1.06552066e-02,
           -1.60846617e-02, 8.15861952e-03, 1.24731511e-02,
           -6.40225224e-03, 3.41185322e-03, 7.65979569e-03,
           1.00506218e-02, -4.69959201e-03, -8.86818208e-03,
           -2.78789271e-03, -2.27947868e-02, -8.15463625e-03,
           6.69314712e-03, -8.55318457e-03, 3.78735829e-03,
           2.02723481e-02, -5.40052773e-03, 3.31624458e-03,
           -1.54345119e-02, 9.38425120e-03]]], dtype=np.float32
    ).reshape(-1, KV_CHANNELS * NUM_HEAD)

    rope_k_emb1_for_prefill = np.array(
        [[[0.00934116, -0.00655473, -0.0096474, -0.00440213,
           -0.01250903, 0.00543782, 0.01050946, 0.00117965,
           -0.01808849, 0.01594066, 0.00479259, 0.00231449,
           -0.01611635, -0.00113272, 0.00043416, 0.00104601,
           0.00577569, -0.00843591, -0.00485385, -0.00759943,
           -0.00229918, -0.01825233, 0.00956456, 0.00470442,
           0.00755409, -0.00162371, 0.00283398, -0.00499728,
           0.0157085, -0.0125403, -0.01369309, -0.00135164],
          [-0.00600774, 0.01318087, 0.00751857, 0.00816627,
           -0.00809037, -0.0096536, -0.01047112, 0.00786659,
           0.00933447, 0.00877941, -0.00952378, -0.01436589,
           0.00077368, 0.00465401, 0.00202908, 0.00769578,
           0.0058117, 0.02309195, 0.01068325, -0.01499643,
           -0.00769764, -0.00282934, 0.0180698, -0.00686065,
           -0.00306341, -0.00223855, 0.00699147, -0.00324152,
           0.0027718, -0.01390233, 0.00090867, 0.01171165],
          [0.00566894, -0.00231632, 0.0085875, 0.00842386,
           0.00405598, 0.0068562, 0.01049533, -0.00729952,
           -0.00609996, 0.00648784, 0.01709718, -0.00112242,
           -0.01035759, 0.01223356, -0.00321697, -0.00406236,
           -0.02380789, -0.00136414, 0.00595801, -0.01012689,
           0.00879869, -0.01051155, -0.0064062, -0.00993737,
           -0.00937491, -0.01533977, -0.00393356, -0.00464522,
           -0.0035962, 0.00135095, -0.01737823, 0.00420752],
          [-0.00157997, -0.00619276, 0.00311472, -0.00104023,
           0.01632158, 0.01096822, 0.00777623, -0.01417324,
           -0.0004948, 0.00651393, 0.00911529, -0.00096117,
           0.01968941, 0.00432763, -0.00636199, -0.00672211,
           0.00716637, 0.00561806, 0.00695412, -0.0019441,
           0.012666, -0.01404032, -0.00566378, 0.00137474,
           0.00854271, 0.0047151, -0.01046264, 0.01110856,
           0.01529155, 0.01786018, 0.00347885, 0.02156327]]], dtype=np.float32
    ).reshape(-1, KV_CHANNELS * NUM_HEAD)

    rope_q_emb1_for_decode = np.array(
        [[[0.00018698, -0.00104761, -0.00765957, -0.00298842,
           -0.01221701, 0.00645509, 0.0102018, 0.00109581,
           -0.01816313, 0.01594953, 0.00478361, 0.00232337,
           -0.01613205, -0.00112567, 0.00043849, 0.00104625,
           0.01098093, -0.01063163, -0.00761335, -0.00825829,
           -0.00353651, -0.01791784, 0.00989206, 0.00472465,
           0.00737284, -0.00153405, 0.00284912, -0.00499316,
           0.01569237, -0.01254094, -0.01369295, -0.00135146]],

         [[-0.00600774, 0.01318087, 0.00751857, 0.00816627,
           -0.00809037, -0.0096536, -0.01047112, 0.00786659,
           0.00933447, 0.00877941, -0.00952378, -0.01436589,
           0.00077368, 0.00465401, 0.00202908, 0.00769578,
           0.0058117, 0.02309195, 0.01068325, -0.01499643,
           -0.00769764, -0.00282934, 0.0180698, -0.00686065,
           -0.00306341, -0.00223855, 0.00699147, -0.00324152,
           0.0027718, -0.01390233, 0.00090867, 0.01171165]]], dtype=np.float32
    ).reshape(-1, KV_CHANNELS * NUM_HEAD)

    rope_k_emb1_for_decode = np.array(
        [[[0.02309659, -0.00123231, 0.00630885, 0.01008239,
           0.00315732, 0.00743616, 0.01069263, -0.00712166,
           -0.00600591, 0.006574, 0.01710953, -0.00111416,
           -0.01035399, 0.0122328, -0.00321147, -0.0040631,
           -0.00809321, -0.00238907, 0.00833315, -0.00847708,
           0.00915966, -0.01010958, -0.00607116, -0.0100656,
           -0.00943544, -0.01530304, -0.00387947, -0.0046472,
           -0.00360655, 0.00135783, -0.01737924, 0.00420679]],

         [[-0.00157997, -0.00619276, 0.00311472, -0.00104023,
           0.01632158, 0.01096822, 0.00777623, -0.01417324,
           -0.0004948, 0.00651393, 0.00911529, -0.00096117,
           0.01968941, 0.00432763, -0.00636199, -0.00672211,
           0.00716637, 0.00561806, 0.00695412, -0.0019441,
           0.012666, -0.01404032, -0.00566378, 0.00137474,
           0.00854271, 0.0047151, -0.01046264, 0.01110856,
           0.01529155, 0.01786018, 0.00347885, 0.02156327]]], dtype=np.float32
    ).reshape(-1, KV_CHANNELS * NUM_HEAD)

    # ---- Values returned under the "llama3_rope_*" keys ----
    # NOTE(review): presumably the Llama3-style scaled RoPE variant -- confirm
    # against the consuming test.
    llama3_rope_q_emb1_for_prefill = np.array(
        [[[-9.23901796e-04, 7.34285591e-03, -1.43888202e-02,
           -6.63422048e-03, -1.00728089e-03, 2.14698650e-02,
           1.39370812e-02, -2.07371125e-03, 7.63587281e-03,
           2.05053645e-03, -1.04149617e-02, -8.28076061e-03,
           -1.05379568e-02, -8.87586619e-04, 1.58193358e-03,
           -1.78932305e-02, 1.06348181e-02, 3.71200062e-04,
           -1.63757929e-03, -3.28207575e-03, 1.58897657e-02,
           -1.88783021e-03, -1.81722967e-03, 7.17900228e-03,
           -7.07250865e-06, 7.82166980e-03, 8.69911222e-04,
           6.62555685e-04, 1.20056495e-02, -6.16653124e-04,
           -9.00906324e-03, -1.25924805e-02],
          [5.99055039e-03, -1.38755748e-02, 7.45566655e-03,
           5.20251971e-03, -2.42165476e-03, -5.81937097e-03,
           -8.96211248e-04, 5.86589100e-03, 8.58902466e-03,
           -1.19406814e-02, 1.90138463e-02, -6.44924818e-03,
           -1.18231382e-02, 1.55227566e-02, 1.21637648e-02,
           9.36356001e-03, -3.17346607e-03, 3.89497355e-03,
           1.28487998e-03, -1.32205570e-02, -6.99169608e-03,
           -1.29683232e-02, 9.00635403e-03, 1.17629627e-02,
           -1.03821829e-02, -1.57267193e-03, 3.95347923e-03,
           -1.16641792e-02, 9.36915074e-03, -1.00450721e-02,
           -1.39072109e-02, -8.46249238e-03],
          [-3.72082670e-03, -8.59507546e-03, 1.20599866e-02,
           1.88908107e-05, -1.42772831e-02, 3.43290414e-03,
           -1.43426992e-02, 1.53077245e-02, -1.20205143e-02,
           -4.45510959e-03, 1.91527209e-03, -7.78383156e-03,
           -9.85337514e-03, 2.52749771e-03, -1.56111876e-02,
           1.04025351e-02, -1.29258167e-02, -1.25538278e-02,
           -1.44812127e-03, -2.21166927e-02, -1.78887552e-04,
           -7.29634427e-03, 3.33205494e-03, -1.08462316e-03,
           8.99155159e-03, -8.99989158e-03, 7.95922987e-03,
           -2.33992505e-05, -5.18175308e-03, 2.13852036e-03,
           -7.76226446e-03, -5.92254195e-03],
          [-1.52982026e-02, -2.72572273e-03, -4.52261744e-03,
           -2.08162479e-02, 1.53162312e-02, 1.74533762e-02,
           -5.99144446e-03, 8.37902352e-03, 5.55018755e-03,
           -4.27198783e-03, -2.06089369e-03, 1.06867375e-02,
           -1.60893817e-02, 8.16025119e-03, 1.24688800e-02,
           -6.40079193e-03, 3.41185322e-03, 7.65979569e-03,
           1.00506218e-02, -4.69959201e-03, -8.86818208e-03,
           -2.78789271e-03, -2.27181017e-02, -8.26621801e-03,
           6.64483802e-03, -8.53226800e-03, 3.79307545e-03,
           2.02557445e-02, -5.38645172e-03, 3.31222988e-03,
           -1.54379634e-02, 9.38524678e-03]]], dtype=np.float32
    ).reshape(-1, KV_CHANNELS * NUM_HEAD)

    llama3_rope_k_emb1_for_prefill = np.array(
        [[[0.00934116, -0.00655473, -0.0096474, -0.00440213,
           -0.01250903, 0.00543782, 0.01050946, 0.00117965,
           -0.01808849, 0.01594066, 0.00479259, 0.00231449,
           -0.01611635, -0.00113272, 0.00043416, 0.00104601,
           0.00577569, -0.00843591, -0.00485385, -0.00759943,
           -0.00229918, -0.01825233, 0.00956456, 0.00470442,
           0.00755409, -0.00162371, 0.00283398, -0.00499728,
           0.0157085, -0.0125403, -0.01369309, -0.00135164],
          [-0.00600774, 0.01318087, 0.00751857, 0.00816627,
           -0.00809037, -0.0096536, -0.01023304, 0.00777513,
           0.00930731, 0.00876829, -0.0095044, -0.01437091,
           0.0007761, 0.00464717, 0.00202933, 0.0076976,
           0.0058117, 0.02309195, 0.01068325, -0.01499643,
           -0.00769764, -0.00282934, 0.01820568, -0.00696412,
           -0.00314497, -0.00228172, 0.0070178, -0.00321916,
           0.00277112, -0.01390462, 0.00090811, 0.01171045],
          [0.00566894, -0.00231632, 0.0085875, 0.00842386,
           0.00405598, 0.0068562, 0.01049533, -0.00729952,
           -0.00609996, 0.00648784, 0.01709718, -0.00112242,
           -0.01035759, 0.01223356, -0.00321697, -0.00406236,
           -0.02380789, -0.00136414, 0.00595801, -0.01012689,
           0.00879869, -0.01051155, -0.0064062, -0.00993737,
           -0.00937491, -0.01533977, -0.00393356, -0.00464522,
           -0.0035962, 0.00135095, -0.01737823, 0.00420752],
          [-0.00157997, -0.00619276, 0.00311472, -0.00104023,
           0.01632158, 0.01096822, 0.00770122, -0.01415381,
           -0.00042003, 0.00653705, 0.0090863, -0.00094388,
           0.01970278, 0.00433641, -0.00636102, -0.00671875,
           0.00716637, 0.00561806, 0.00695412, -0.0019441,
           0.012666, -0.01404032, -0.00576536, 0.00156213,
           0.00854672, 0.004683, -0.01048782, 0.01111004,
           0.01527432, 0.01785805, 0.00348061, 0.02156432]]], dtype=np.float32
    ).reshape(-1, KV_CHANNELS * NUM_HEAD)

    llama3_rope_q_emb1_for_decode = np.array(
        [[[0.00018698, -0.00104761, -0.00765957, -0.00298842,
           -0.01221701, 0.00645509, 0.01033076, 0.00115822,
           -0.01809792, 0.01594179, 0.00479147, 0.0023156,
           -0.01611832, -0.00113184, 0.0004347, 0.00104604,
           0.01098093, -0.01063163, -0.00761335, -0.00825829,
           -0.00353651, -0.01791784, 0.0097573, 0.00470974,
           0.00753148, -0.00161251, 0.00283587, -0.00499676,
           0.01570648, -0.01254038, -0.01369308, -0.00135162]],

         [[-0.00600774, 0.01318087, 0.00751857, 0.00816627,
           -0.00809037, -0.0096536, -0.01023304, 0.00777513,
           0.00930731, 0.00876829, -0.0095044, -0.01437091,
           0.0007761, 0.00464717, 0.00202933, 0.0076976,
           0.0058117, 0.02309195, 0.01068325, -0.01499643,
           -0.00769764, -0.00282934, 0.01820568, -0.00696412,
           -0.00314497, -0.00228172, 0.0070178, -0.00321916,
           0.00277112, -0.01390462, 0.00090811, 0.01171045]]], dtype=np.float32
    ).reshape(-1, KV_CHANNELS * NUM_HEAD)

    llama3_rope_k_emb1_for_decode = np.array(
        [[[0.02309659, -0.00123231, 0.00630885, 0.01008239,
           0.00315732, 0.00743616, 0.01061202, -0.00725421,
           -0.00608823, 0.00649862, 0.01709873, -0.00112139,
           -0.01035714, 0.01223347, -0.00321628, -0.00406245,
           -0.00809321, -0.00238907, 0.00833315, -0.00847708,
           0.00915966, -0.01010958, -0.00621098, -0.0099705,
           -0.00938252, -0.01533521, -0.0039268, -0.00464547,
           -0.00359749, 0.00135181, -0.01737836, 0.00420743]],

         [[-0.00157997, -0.00619276, 0.00311472, -0.00104023,
           0.01632158, 0.01096822, 0.00770122, -0.01415381,
           -0.00042003, 0.00653705, 0.0090863, -0.00094388,
           0.01970278, 0.00433641, -0.00636102, -0.00671875,
           0.00716637, 0.00561806, 0.00695412, -0.0019441,
           0.012666, -0.01404032, -0.00576536, 0.00156213,
           0.00854672, 0.004683, -0.01048782, 0.01111004,
           0.01527432, 0.01785805, 0.00348061, 0.02156432]]], dtype=np.float32
    ).reshape(-1, KV_CHANNELS * NUM_HEAD)

    # ---- Values returned under the "yarn_rope_*" keys ----
    # NOTE(review): presumably the YaRN-scaled RoPE variant -- confirm against
    # the consuming test.
    yarn_rope_q_emb1_for_prefill = np.array(
        [[[-1.11602177e-03, 8.86975974e-03, -1.73808914e-02,
           -8.01376812e-03, -1.21673907e-03, 2.59343982e-02,
           1.68352164e-02, -2.50492734e-03, 9.22370795e-03,
           2.47693341e-03, -1.25806918e-02, -1.00026960e-02,
           -1.27292629e-02, -1.07215508e-03, 1.91088743e-03,
           -2.16140226e-02, 1.28462659e-02, 4.48388950e-04,
           -1.97810424e-03, -3.96456430e-03, 1.91939492e-02,
           -2.28039338e-03, -2.19511194e-03, 8.67183413e-03,
           -8.54319569e-06, 9.44813993e-03, 1.05080416e-03,
           8.00330250e-04, 1.45021537e-02, -7.44882564e-04,
           -1.08824456e-02, -1.52110131e-02],
          [7.23625021e-03, -1.67609192e-02, 9.06032789e-03,
           5.56841306e-03, -3.23980255e-03, -7.46709295e-03,
           -1.01033237e-03, 7.08717946e-03, 1.03750620e-02,
           -1.44236777e-02, 2.29676627e-02, -7.79033126e-03,
           -1.42816901e-02, 1.87506229e-02, 1.46931484e-02,
           1.13106575e-02, -3.83337005e-03, 4.70491173e-03,
           1.19494891e-03, -1.62332095e-02, -8.32997169e-03,
           -1.54611962e-02, 1.08861178e-02, 1.42082497e-02,
           -1.25410976e-02, -1.89969991e-03, 4.77558188e-03,
           -1.40896766e-02, 1.13174105e-02, -1.21338870e-02,
           -1.67991333e-02, -1.02222180e-02],
          [-4.49455064e-03, -1.03823710e-02, 1.45677906e-02,
           2.28190438e-05, -1.72461607e-02, 4.14675660e-03,
           -1.73251797e-02, 1.84908770e-02, -1.45201096e-02,
           -5.38152363e-03, 2.31354171e-03, -9.40243341e-03,
           -1.19023267e-02, 3.05307610e-03, -1.88574418e-02,
           1.25656817e-02, -1.56136649e-02, -1.51643232e-02,
           -1.74924964e-03, -2.67157294e-02, -2.16086177e-04,
           -8.81357677e-03, 4.02493635e-03, -1.31016423e-03,
           1.08612925e-02, -1.08713666e-02, 9.61430557e-03,
           -2.82649871e-05, -6.25926815e-03, 2.58321315e-03,
           -9.37638152e-03, -7.15410011e-03],
          [-1.84793733e-02, -3.29252076e-03, -4.97902790e-03,
           -2.53723133e-02, 1.80865284e-02, 2.09796969e-02,
           -7.41933426e-03, 1.01203313e-02, 6.70431601e-03,
           -5.16032288e-03, -2.48944433e-03, 1.29089821e-02,
           -1.94350742e-02, 9.85712744e-03, 1.50617119e-02,
           -7.73179904e-03, 4.12132777e-03, 9.25260596e-03,
           1.23469960e-02, -4.55373898e-03, -1.13983685e-02,
           -3.95899359e-03, -2.73935497e-02, -9.98620596e-03,
           8.02659336e-03, -1.03065027e-02, 4.58182301e-03,
           2.44678073e-02, -6.50653290e-03, 4.00098879e-03,
           -1.86481960e-02, 1.13368537e-02]]], dtype=np.float32
    ).reshape(-1, KV_CHANNELS * NUM_HEAD)

    yarn_rope_k_emb1_for_prefill = np.array(
        [[[0.0112836, -0.00791775, -0.01165351, -0.00531752,
           -0.01511021, 0.00656859, 0.01269485, 0.00142495,
           -0.02184989, 0.01925542, 0.00578918, 0.00279577,
           -0.01946765, -0.00136827, 0.00052444, 0.00126352,
           0.00697671, -0.01019011, -0.00586318, -0.00917969,
           -0.00277728, -0.02204779, 0.01155345, 0.00568268,
           0.00912492, -0.00196136, 0.00342328, -0.00603643,
           0.01897499, -0.01514799, -0.01654049, -0.00163271],
          [-0.00725701, 0.01592175, 0.00958489, 0.00904959,
           -0.01011446, -0.01175249, -0.01221469, 0.00939103,
           0.01124271, 0.0105916, -0.01148078, -0.01735926,
           0.00093749, 0.00561352, 0.00245132, 0.00929827,
           0.00702021, 0.02789378, 0.01253578, -0.01853535,
           -0.00892539, -0.0030885, 0.02207302, -0.00841327,
           -0.00379895, -0.00275619, 0.00847711, -0.00388857,
           0.00334736, -0.016796, 0.00109695, 0.01414557],
          [0.00684776, -0.00279799, 0.01037322, 0.01017555,
           0.0048994, 0.00828191, 0.01267777, -0.00881741,
           -0.00736841, 0.00783695, 0.02065243, -0.00135582,
           -0.01251139, 0.01477746, -0.00388592, -0.0049071,
           -0.0287586, -0.00164781, 0.00719694, -0.01223272,
           0.01062833, -0.01269736, -0.00773833, -0.01200379,
           -0.01132436, -0.01852959, -0.00475152, -0.00561116,
           -0.00434401, 0.00163188, -0.02099193, 0.00508244],
          [-0.00190852, -0.0074805, 0.00409142, -0.00135967,
           0.0202753, 0.01276697, 0.00925621, -0.01709681,
           -0.00050737, 0.00789639, 0.01097574, -0.00114015,
           0.02379986, 0.00523815, -0.00768376, -0.00811588,
           0.00865658, 0.0067863, 0.00824494, -0.0022902,
           0.0145499, -0.01732569, -0.00702583, 0.00188879,
           0.01032395, 0.0056568, -0.0126687, 0.01342031,
           0.01845052, 0.02157152, 0.00420438, 0.02604849]]], dtype=np.float32
    ).reshape(-1, KV_CHANNELS * NUM_HEAD)

    yarn_rope_q_emb1_for_decode = np.array(
        [[[0.00022586, -0.00126545, -0.00960853, -0.00404962,
           -0.01490725, 0.00718583, 0.01255695, 0.00139967,
           -0.02186128, 0.01925679, 0.00578783, 0.00279711,
           -0.01947002, -0.0013672, 0.00052509, 0.00126356,
           0.01326435, -0.01284242, -0.00882369, -0.00980527,
           -0.00371563, -0.02185441, 0.01170318, 0.00568896,
           0.00909761, -0.00194782, 0.00342557, -0.00603581,
           0.01897256, -0.01514808, -0.01654047, -0.00163268]],

         [[-0.00725701, 0.01592175, 0.00958489, 0.00904959,
           -0.01011446, -0.01175249, -0.01221469, 0.00939103,
           0.01124271, 0.0105916, -0.01148078, -0.01735926,
           0.00093749, 0.00561352, 0.00245132, 0.00929827,
           0.00702021, 0.02789378, 0.01253578, -0.01853535,
           -0.00892539, -0.0030885, 0.02207302, -0.00841327,
           -0.00379895, -0.00275619, 0.00847711, -0.00388857,
           0.00334736, -0.016796, 0.00109695, 0.01414557]]], dtype=np.float32
    ).reshape(-1, KV_CHANNELS * NUM_HEAD)

    yarn_rope_k_emb1_for_decode = np.array(
        [[[0.02789939, -0.00148857, 0.00801257, 0.01171184,
           0.004226, 0.0086356, 0.01276864, -0.00876396,
           -0.00735425, 0.00784997, 0.02065431, -0.00135458,
           -0.01251085, 0.01477734, -0.00388509, -0.00490721,
           -0.00977615, -0.00288586, 0.00975696, -0.01077098,
           0.01091359, -0.01245951, -0.00758745, -0.01204287,
           -0.01133357, -0.01852407, -0.00474335, -0.00561146,
           -0.00434557, 0.00163291, -0.02099208, 0.00508234]],

         [[-0.00190852, -0.0074805, 0.00409142, -0.00135967,
           0.0202753, 0.01276697, 0.00925621, -0.01709681,
           -0.00050737, 0.00789639, 0.01097574, -0.00114015,
           0.02379986, 0.00523815, -0.00768376, -0.00811588,
           0.00865658, 0.0067863, 0.00824494, -0.0022902,
           0.0145499, -0.01732569, -0.00702583, 0.00188879,
           0.01032395, 0.0056568, -0.0126687, 0.01342031,
           0.01845052, 0.02157152, 0.00420438, 0.02604849]]], dtype=np.float32
    ).reshape(-1, KV_CHANNELS * NUM_HEAD)

    # Each array is exposed under a key equal to its local variable name.
    return {
        "rope_q_emb1_for_prefill": rope_q_emb1_for_prefill,
        "rope_k_emb1_for_prefill": rope_k_emb1_for_prefill,
        "rope_q_emb1_for_decode": rope_q_emb1_for_decode,
        "rope_k_emb1_for_decode": rope_k_emb1_for_decode,
        "llama3_rope_q_emb1_for_prefill": llama3_rope_q_emb1_for_prefill,
        "llama3_rope_k_emb1_for_prefill": llama3_rope_k_emb1_for_prefill,
        "llama3_rope_q_emb1_for_decode": llama3_rope_q_emb1_for_decode,
        "llama3_rope_k_emb1_for_decode": llama3_rope_k_emb1_for_decode,
        "yarn_rope_q_emb1_for_prefill": yarn_rope_q_emb1_for_prefill,
        "yarn_rope_k_emb1_for_prefill": yarn_rope_k_emb1_for_prefill,
        "yarn_rope_q_emb1_for_decode": yarn_rope_q_emb1_for_decode,
        "yarn_rope_k_emb1_for_decode": yarn_rope_k_emb1_for_decode,
    }


def get_gpu_data() -> dict[str, np.ndarray]:
    """Generate gpu data for test."""
    rope_q_emb1_for_prefill = np.array(
        [[[-9.2316e-04, 7.3547e-03, -1.4404e-02, -6.6223e-03,
           -1.0071e-03, 2.1484e-02, 1.3916e-02, -2.0752e-03,
           7.6294e-03, 2.0447e-03, -1.0437e-02, -8.3008e-03,
           -1.0559e-02, -8.8882e-04, 1.5793e-03, -1.7944e-02,
           1.0620e-02, 3.7193e-04, -1.6403e-03, -3.2806e-03,
           1.5869e-02, -1.8845e-03, -1.8158e-03, 7.1716e-03,
           -7.0333e-06, 7.8125e-03, 8.6975e-04, 6.6376e-04,
           1.2024e-02, -6.1798e-04, -9.0332e-03, -1.2573e-02],
          [5.9814e-03, -1.3855e-02, 7.4463e-03, 5.1880e-03,
           -2.4261e-03, -5.8289e-03, -1.0147e-03, 5.7068e-03,
           8.7280e-03, -1.1963e-02, 1.9043e-02, -6.4087e-03,
           -1.1841e-02, 1.5503e-02, 1.2146e-02, 9.3384e-03,
           -3.1586e-03, 3.9062e-03, 1.2741e-03, -1.3184e-02,
           -6.9885e-03, -1.3000e-02, 9.0332e-03, 1.1841e-02,
           -1.0315e-02, -1.6327e-03, 3.9978e-03, -1.1658e-02,
           9.3994e-03, -1.0071e-02, -1.3916e-02, -8.4839e-03],
          [-3.7231e-03, -8.6060e-03, 1.2085e-02, 1.8835e-05,
           -1.4282e-02, 3.4332e-03, -1.4343e-02, 1.5320e-02,
           -1.2024e-02, -4.4556e-03, 1.9150e-03, -7.7820e-03,
           -9.8267e-03, 2.5330e-03, -1.5625e-02, 1.0376e-02,
           -1.2939e-02, -1.2573e-02, -1.4496e-03, -2.2095e-02,
           -1.7929e-04, -7.2937e-03, 3.3264e-03, -1.0834e-03,
           8.9722e-03, -8.9722e-03, 7.9346e-03, -2.3365e-05,
           -5.1880e-03, 2.1362e-03, -7.7515e-03, -5.9204e-03],
          [-1.5259e-02, -2.7008e-03, -4.5166e-03, -2.0874e-02,
           1.5320e-02, 1.7456e-02, -5.6763e-03, 8.4839e-03,
           5.4932e-03, -4.2114e-03, -2.0752e-03, 1.0620e-02,
           -1.6113e-02, 8.1787e-03, 1.2451e-02, -6.4087e-03,
           3.3875e-03, 7.6599e-03, 1.0071e-02, -4.6997e-03,
           -8.9111e-03, -2.8076e-03, -2.2827e-02, -8.1787e-03,
           6.7139e-03, -8.5449e-03, 3.7994e-03, 2.0264e-02,
           -5.4016e-03, 3.3112e-03, -1.5442e-02, 9.3994e-03]]], dtype=np.float16
    ).reshape(-1, KV_CHANNELS * NUM_HEAD)

    rope_k_emb1_for_prefill = np.array(
        [[[0.00934, -0.00656, -0.00964, -0.004395, -0.01251,
           0.005432, 0.0105, 0.001183, -0.01807, 0.01599,
           0.00479, 0.00232, -0.01611, -0.001129, 0.0004349,
           0.001045, 0.005768, -0.00842, -0.004852, -0.0076,
           -0.002304, -0.01831, 0.00958, 0.0047, 0.00757,
           -0.001625, 0.002838, -0.005005, 0.01575, -0.01251,
           -0.01367, -0.00135],
          [-0.00598, 0.013245, 0.007507, 0.00818, -0.00806,
           -0.009705, -0.0105, 0.00787, 0.00934, 0.00879,
           -0.00952, -0.01434, 0.0007744, 0.00464, 0.00203,
           0.00769, 0.005768, 0.02307, 0.01068, -0.015015,
           -0.00772, -0.002838, 0.01807, -0.006836, -0.003067,
           -0.002243, 0.00699, -0.00325, 0.002777, -0.013916,
           0.000908, 0.01172],
          [0.005676, -0.00232, 0.008606, 0.00842, 0.00406,
           0.006866, 0.0105, -0.007294, -0.006104, 0.0065,
           0.01709, -0.0011215, -0.010376, 0.01221, -0.00322,
           -0.00406, -0.0238, -0.001366, 0.00595, -0.01013,
           0.00879, -0.0105, -0.00641, -0.00995, -0.0094,
           -0.01532, -0.003937, -0.00464, -0.003601, 0.00135,
           -0.01733, 0.00421],
          [-0.001572, -0.006165, 0.003098, -0.001038, 0.01624,
           0.01099, 0.007782, -0.01416, -0.000496, 0.0065,
           0.009155, -0.0009613, 0.01965, 0.004333, -0.006348,
           -0.006714, 0.00714, 0.005646, 0.006958, -0.0019455,
           0.012695, -0.0141, -0.005676, 0.001373, 0.008545,
           0.0047, -0.0105, 0.01111, 0.01526, 0.01782,
           0.003479, 0.0216]]], dtype=np.float16
    ).reshape(-1, KV_CHANNELS * NUM_HEAD)

    rope_q_emb1_for_decode = np.array(
        [[[0.0001831, -0.001068, -0.00763, -0.00299, -0.01221,
           0.00647, 0.01019, 0.001099, -0.01819, 0.01599,
           0.00479, 0.002335, -0.01611, -0.0011215, 0.0004387,
           0.001045, 0.010925, -0.01062, -0.0076, -0.00824,
           -0.00354, -0.01807, 0.00989, 0.00473, 0.007385,
           -0.0015335, 0.002853, -0.005005, 0.01575, -0.01251,
           -0.01367, -0.00135]],

         [[-0.00598, 0.013245, 0.007507, 0.00818, -0.00806,
           -0.009705, -0.0105, 0.00787, 0.00934, 0.00879,
           -0.00952, -0.01434, 0.0007744, 0.00464, 0.00203,
           0.00769, 0.005768, 0.02307, 0.01068, -0.015015,
           -0.00772, -0.002838, 0.01807, -0.006836, -0.003067,
           -0.002243, 0.00699, -0.00325, 0.002777, -0.013916,
           0.000908, 0.01172]]], dtype=np.float16
    ).reshape(-1, KV_CHANNELS * NUM_HEAD)

    rope_k_emb1_for_decode = np.array(
        [[[0.02307, -0.001244, 0.006348, 0.01007, 0.003143,
           0.007446, 0.01068, -0.00711, -0.006012, 0.00659,
           0.01709, -0.001114, -0.010376, 0.01221, -0.00322,
           -0.00406, -0.00806, -0.002396, 0.0083, -0.008484,
           0.009155, -0.01013, -0.006073, -0.01007, -0.00946,
           -0.01526, -0.003876, -0.00464, -0.003616, 0.001358,
           -0.01733, 0.00421]],

         [[-0.001572, -0.006165, 0.003098, -0.001038, 0.01624,
           0.01099, 0.007782, -0.01416, -0.000496, 0.0065,
           0.009155, -0.0009613, 0.01965, 0.004333, -0.006348,
           -0.006714, 0.00714, 0.005646, 0.006958, -0.0019455,
           0.012695, -0.0141, -0.005676, 0.001373, 0.008545,
           0.0047, -0.0105, 0.01111, 0.01526, 0.01782,
           0.003479, 0.0216]]], dtype=np.float16
    ).reshape(-1, KV_CHANNELS * NUM_HEAD)

    llama3_rope_q_emb1_for_prefill = np.array(
        [[[-9.2316e-04, 7.3547e-03, -1.4404e-02, -6.6223e-03,
           -1.0071e-03, 2.1484e-02, 1.3916e-02, -2.0752e-03,
           7.6294e-03, 2.0447e-03, -1.0437e-02, -8.3008e-03,
           -1.0559e-02, -8.8882e-04, 1.5793e-03, -1.7944e-02,
           1.0620e-02, 3.7193e-04, -1.6403e-03, -3.2806e-03,
           1.5869e-02, -1.8845e-03, -1.8158e-03, 7.1716e-03,
           -7.0333e-06, 7.8125e-03, 8.6975e-04, 6.6376e-04,
           1.2024e-02, -6.1798e-04, -9.0332e-03, -1.2573e-02],
          [5.9814e-03, -1.3855e-02, 7.4463e-03, 5.1880e-03,
           -2.4261e-03, -5.8289e-03, -8.9645e-04, 5.8594e-03,
           8.6060e-03, -1.1963e-02, 1.9043e-02, -6.4392e-03,
           -1.1841e-02, 1.5503e-02, 1.2146e-02, 9.3384e-03,
           -3.1586e-03, 3.9062e-03, 1.2741e-03, -1.3184e-02,
           -6.9885e-03, -1.3000e-02, 9.0332e-03, 1.1719e-02,
           -1.0376e-02, -1.5717e-03, 3.9368e-03, -1.1658e-02,
           9.3994e-03, -1.0071e-02, -1.3916e-02, -8.4839e-03],
          [-3.7231e-03, -8.6060e-03, 1.2085e-02, 1.8835e-05,
           -1.4282e-02, 3.4332e-03, -1.4343e-02, 1.5320e-02,
           -1.2024e-02, -4.4556e-03, 1.9150e-03, -7.7820e-03,
           -9.8267e-03, 2.5330e-03, -1.5625e-02, 1.0376e-02,
           -1.2939e-02, -1.2573e-02, -1.4496e-03, -2.2095e-02,
           -1.7929e-04, -7.2937e-03, 3.3264e-03, -1.0834e-03,
           8.9722e-03, -8.9722e-03, 7.9346e-03, -2.3365e-05,
           -5.1880e-03, 2.1362e-03, -7.7515e-03, -5.9204e-03],
          [-1.5259e-02, -2.7008e-03, -4.5166e-03, -2.0874e-02,
           1.5320e-02, 1.7456e-02, -5.9814e-03, 8.4229e-03,
           5.5542e-03, -4.2725e-03, -2.0599e-03, 1.0681e-02,
           -1.6113e-02, 8.1787e-03, 1.2451e-02, -6.4087e-03,
           3.3875e-03, 7.6599e-03, 1.0071e-02, -4.6997e-03,
           -8.9111e-03, -2.8076e-03, -2.2705e-02, -8.2397e-03,
           6.6528e-03, -8.5449e-03, 3.7994e-03, 2.0264e-02,
           -5.3711e-03, 3.3112e-03, -1.5442e-02, 9.3994e-03]]], dtype=np.float16
    ).reshape(-1, KV_CHANNELS * NUM_HEAD)

    llama3_rope_k_emb1_for_prefill = np.array(
        [[[0.00934, -0.00656, -0.00964, -0.004395, -0.01251,
           0.005432, 0.0105, 0.001183, -0.01807, 0.01599,
           0.00479, 0.00232, -0.01611, -0.001129, 0.0004349,
           0.001045, 0.005768, -0.00842, -0.004852, -0.0076,
           -0.002304, -0.01831, 0.00958, 0.0047, 0.00757,
           -0.001625, 0.002838, -0.005005, 0.01575, -0.01251,
           -0.01367, -0.00135],
          [-0.00598, 0.013245, 0.007507, 0.00818, -0.00806,
           -0.009705, -0.010254, 0.007782, 0.00928, 0.00879,
           -0.00952, -0.01434, 0.000778, 0.00464, 0.00203,
           0.00769, 0.005768, 0.02307, 0.01068, -0.015015,
           -0.00772, -0.002838, 0.01831, -0.006958, -0.003143,
           -0.002289, 0.00702, -0.00322, 0.002777, -0.013916,
           0.000908, 0.01172],
          [0.005676, -0.00232, 0.008606, 0.00842, 0.00406,
           0.006866, 0.0105, -0.007294, -0.006104, 0.0065,
           0.01709, -0.0011215, -0.010376, 0.01221, -0.00322,
           -0.00406, -0.0238, -0.001366, 0.00595, -0.01013,
           0.00879, -0.0105, -0.00641, -0.00995, -0.0094,
           -0.01532, -0.003937, -0.00464, -0.003601, 0.00135,
           -0.01733, 0.00421],
          [-0.001572, -0.006165, 0.003098, -0.001038, 0.01624,
           0.01099, 0.00772, -0.01416, -0.0004215, 0.00653,
           0.009094, -0.000946, 0.01965, 0.004333, -0.006348,
           -0.006714, 0.00714, 0.005646, 0.006958, -0.0019455,
           0.012695, -0.0141, -0.005768, 0.001564, 0.008545,
           0.00467, -0.0105, 0.01111, 0.01526, 0.01782,
           0.003479, 0.0216]]], dtype=np.float16
    ).reshape(-1, KV_CHANNELS * NUM_HEAD)

    llama3_rope_q_emb1_for_decode = np.array(
        [[[0.0001831, -0.001068, -0.00763, -0.00299, -0.01221,
           0.00647, 0.010315, 0.00116, -0.01807, 0.01599,
           0.00479, 0.00232, -0.01611, -0.001129, 0.0004349,
           0.001045, 0.010925, -0.01062, -0.0076, -0.00824,
           -0.00354, -0.01807, 0.009766, 0.0047, 0.007538,
           -0.001617, 0.002838, -0.005005, 0.01575, -0.01251,
           -0.01367, -0.00135]],

         [[-0.00598, 0.013245, 0.007507, 0.00818, -0.00806,
           -0.009705, -0.010254, 0.007782, 0.00928, 0.00879,
           -0.00952, -0.01434, 0.000778, 0.00464, 0.00203,
           0.00769, 0.005768, 0.02307, 0.01068, -0.015015,
           -0.00772, -0.002838, 0.01831, -0.006958, -0.003143,
           -0.002289, 0.00702, -0.00322, 0.002777, -0.013916,
           0.000908, 0.01172]]], dtype=np.float16
    ).reshape(-1, KV_CHANNELS * NUM_HEAD)

    llama3_rope_k_emb1_for_decode = np.array(
        [[[0.02307, -0.001244, 0.006348, 0.01007, 0.003143,
           0.007446, 0.01062, -0.007263, -0.006104, 0.0065,
           0.01709, -0.0011215, -0.010376, 0.01221, -0.00322,
           -0.00406, -0.00806, -0.002396, 0.0083, -0.008484,
           0.009155, -0.01013, -0.006226, -0.01001, -0.0094,
           -0.01532, -0.003937, -0.00464, -0.003601, 0.00135,
           -0.01733, 0.00421]],

         [[-0.001572, -0.006165, 0.003098, -0.001038, 0.01624,
           0.01099, 0.00772, -0.01416, -0.0004215, 0.00653,
           0.009094, -0.000946, 0.01965, 0.004333, -0.006348,
           -0.006714, 0.00714, 0.005646, 0.006958, -0.0019455,
           0.012695, -0.0141, -0.005768, 0.001564, 0.008545,
           0.00467, -0.0105, 0.01111, 0.01526, 0.01782,
           0.003479, 0.0216]]], dtype=np.float16
    ).reshape(-1, KV_CHANNELS * NUM_HEAD)

    yarn_rope_q_emb1_for_prefill = np.array(
        [[[-1.1215e-03, 8.9111e-03, -1.7456e-02, -7.9956e-03,
           -1.2207e-03, 2.6001e-02, 1.6846e-02, -2.5177e-03,
           9.2163e-03, 2.4719e-03, -1.2634e-02, -1.0071e-02,
           -1.2756e-02, -1.0757e-03, 1.9150e-03, -2.1729e-02,
           1.2878e-02, 4.5013e-04, -1.9836e-03, -3.9673e-03,
           1.9165e-02, -2.2888e-03, -2.1973e-03, 8.6670e-03,
           -8.5235e-06, 9.4604e-03, 1.0529e-03, 8.0490e-04,
           1.4587e-02, -7.4768e-04, -1.0925e-02, -1.5198e-02],
          [7.2021e-03, -1.6846e-02, 9.0942e-03, 5.5542e-03,
           -3.2349e-03, -7.5073e-03, -1.0071e-03, 7.1106e-03,
           1.0437e-02, -1.4465e-02, 2.3071e-02, -7.8125e-03,
           -1.4343e-02, 1.8799e-02, 1.4709e-02, 1.1292e-02,
           -3.8147e-03, 4.6997e-03, 1.1902e-03, -1.6113e-02,
           -8.3008e-03, -1.5564e-02, 1.0925e-02, 1.4221e-02,
           -1.2573e-02, -1.8997e-03, 4.7607e-03, -1.4099e-02,
           1.1353e-02, -1.2207e-02, -1.6846e-02, -1.0254e-02],
          [-4.5166e-03, -1.0437e-02, 1.4648e-02, 2.2769e-05,
           -1.7334e-02, 4.1504e-03, -1.7334e-02, 1.8555e-02,
           -1.4587e-02, -5.4016e-03, 2.3193e-03, -9.3994e-03,
           -1.1902e-02, 3.0670e-03, -1.8921e-02, 1.2573e-02,
           -1.5625e-02, -1.5198e-02, -1.7548e-03, -2.6733e-02,
           -2.1744e-04, -8.8501e-03, 4.0283e-03, -1.3123e-03,
           1.0864e-02, -1.0864e-02, 9.5825e-03, -2.8253e-05,
           -6.2866e-03, 2.5940e-03, -9.3994e-03, -7.1716e-03],
          [-1.8433e-02, -3.3112e-03, -5.0049e-03, -2.5391e-02,
           1.7944e-02, 2.0996e-02, -7.4158e-03, 1.0193e-02,
           6.7139e-03, -5.1880e-03, -2.4872e-03, 1.2939e-02,
           -1.9531e-02, 9.8877e-03, 1.5076e-02, -7.7515e-03,
           4.0894e-03, 9.2773e-03, 1.2451e-02, -4.5471e-03,
           -1.1414e-02, -3.9978e-03, -2.7466e-02, -1.0010e-02,
           8.0566e-03, -1.0376e-02, 4.6082e-03, 2.4536e-02,
           -6.5002e-03, 3.9978e-03, -1.8677e-02, 1.1353e-02]]], dtype=np.float16
    ).reshape(-1, KV_CHANNELS * NUM_HEAD)

    yarn_rope_k_emb1_for_prefill = np.array(
        [[[0.01129, -0.007935, -0.01166, -0.00531, -0.01514,
           0.00659, 0.012695, 0.001434, -0.02185, 0.01941,
           0.0058, 0.002808, -0.01953, -0.001366, 0.0005264,
           0.0012665, 0.00699, -0.01019, -0.00589, -0.00922,
           -0.002792, -0.02222, 0.0116, 0.005676, 0.009155,
           -0.001968, 0.003433, -0.006073, 0.01904, -0.01514,
           -0.0166, -0.001633],
          [-0.007233, 0.01587, 0.00964, 0.00903, -0.01001,
           -0.01184, -0.01221, 0.00946, 0.01123, 0.01062,
           -0.011536, -0.01733, 0.000942, 0.005615, 0.002457,
           0.00934, 0.00702, 0.02795, 0.01257, -0.01855,
           -0.00891, -0.003113, 0.02222, -0.00842, -0.003815,
           -0.002777, 0.008484, -0.003906, 0.003357, -0.01685,
           0.001099, 0.01416],
          [0.006866, -0.002808, 0.01044, 0.01019, 0.004913,
           0.0083, 0.012695, -0.00885, -0.007385, 0.00787,
           0.02075, -0.001358, -0.01257, 0.01477, -0.003906,
           -0.004913, -0.02881, -0.001656, 0.007202, -0.01227,
           0.01062, -0.012695, -0.00775, -0.012024, -0.01135,
           -0.01855, -0.00476, -0.005615, -0.004364, 0.001633,
           -0.021, 0.005096],
          [-0.001892, -0.007477, 0.00412, -0.001358, 0.02014,
           0.01276, 0.00928, -0.01709, -0.0005074, 0.007935,
           0.01099, -0.001144, 0.0238, 0.00525, -0.00769,
           -0.00812, 0.00867, 0.006805, 0.00824, -0.002289,
           0.01459, -0.01733, -0.00705, 0.001892, 0.010376,
           0.005646, -0.012695, 0.01343, 0.01843, 0.0216,
           0.00421, 0.02612]]], dtype=np.float16
    ).reshape(-1, KV_CHANNELS * NUM_HEAD)

    yarn_rope_q_emb1_for_decode = np.array(
        [[[0.0002441, -0.001282, -0.00964, -0.00403, -0.01489,
           0.007202, 0.01257, 0.001411, -0.02185, 0.01941,
           0.0058, 0.002808, -0.01953, -0.001366, 0.0005264,
           0.0012665, 0.013245, -0.01282, -0.00885, -0.009766,
           -0.003723, -0.02197, 0.01172, 0.005676, 0.009155,
           -0.001953, 0.003433, -0.006073, 0.01904, -0.01514,
           -0.0166, -0.001633]],

         [[-0.007233, 0.01587, 0.00964, 0.00903, -0.01001,
           -0.01184, -0.01221, 0.00946, 0.01123, 0.01062,
           -0.011536, -0.01733, 0.000942, 0.005615, 0.002457,
           0.00934, 0.00702, 0.02795, 0.01257, -0.01855,
           -0.00891, -0.003113, 0.02222, -0.00842, -0.003815,
           -0.002777, 0.008484, -0.003906, 0.003357, -0.01685,
           0.001099, 0.01416]]], dtype=np.float16
    ).reshape(-1, KV_CHANNELS * NUM_HEAD)

    yarn_rope_k_emb1_for_decode = np.array(
        [[[0.02783, -0.001495, 0.00806, 0.01172, 0.00421,
           0.00867, 0.01282, -0.00879, -0.007385, 0.00787,
           0.02075, -0.001358, -0.01257, 0.01477, -0.003906,
           -0.004913, -0.009766, -0.0029, 0.009766, -0.01074,
           0.010864, -0.01245, -0.0076, -0.012085, -0.01135,
           -0.01855, -0.00476, -0.005615, -0.004364, 0.001633,
           -0.021, 0.005096]],

         [[-0.001892, -0.007477, 0.00412, -0.001358, 0.02014,
           0.01276, 0.00928, -0.01709, -0.0005074, 0.007935,
           0.01099, -0.001144, 0.0238, 0.00525, -0.00769,
           -0.00812, 0.00867, 0.006805, 0.00824, -0.002289,
           0.01459, -0.01733, -0.00705, 0.001892, 0.010376,
           0.005646, -0.012695, 0.01343, 0.01843, 0.0216,
           0.00421, 0.02612]]], dtype=np.float16
    ).reshape(-1, KV_CHANNELS * NUM_HEAD)

    return {
        "rope_q_emb1_for_prefill": rope_q_emb1_for_prefill,
        "rope_k_emb1_for_prefill": rope_k_emb1_for_prefill,
        "rope_q_emb1_for_decode": rope_q_emb1_for_decode,
        "rope_k_emb1_for_decode": rope_k_emb1_for_decode,
        "llama3_rope_q_emb1_for_prefill": llama3_rope_q_emb1_for_prefill,
        "llama3_rope_k_emb1_for_prefill": llama3_rope_k_emb1_for_prefill,
        "llama3_rope_q_emb1_for_decode": llama3_rope_q_emb1_for_decode,
        "llama3_rope_k_emb1_for_decode": llama3_rope_k_emb1_for_decode,
        "yarn_rope_q_emb1_for_prefill": yarn_rope_q_emb1_for_prefill,
        "yarn_rope_k_emb1_for_prefill": yarn_rope_k_emb1_for_prefill,
        "yarn_rope_q_emb1_for_decode": yarn_rope_q_emb1_for_decode,
        "yarn_rope_k_emb1_for_decode": yarn_rope_k_emb1_for_decode,
    }


# Both lookup tables below are built once, when this module is imported.
# GOLDEN_DATA holds the dict[str, np.ndarray] produced by get_golden().
GOLDEN_DATA = get_golden()
# GPU_DATA holds whatever get_gpu_data() returns.
# NOTE(review): get_gpu_data is not defined in the visible part of this file;
# presumably it returns the same key layout as get_golden() — TODO confirm.
GPU_DATA = get_gpu_data()
